#include "PrincipalComponentAnalysisMotionSegmenter.h"
#include <MMM/Motion/Plugin/KinematicPlugin/KinematicSensor.h>
#include <MMM/Motion/Plugin/MoCapMarkerPlugin/MoCapMarkerSensor.h>
#include <MMM/Motion/Plugin/ModelPosePlugin/ModelPoseSensor.h>
#include <VirtualRobot/RobotConfig.h>
#include <MMMSimoxTools/MMMSimoxTools.h>
#include <Eigen/Eigenvalues>

using namespace MMM;

/**
 * Creates a PCA segmenter for the given motion and segmentation type, together with the
 * configuration widget that is initialised with the same motion and type.
 *
 * @param motion  motion whose sensors are segmented
 * @param segType which feature set (joint angles, model markers, C3D markers) is used
 */
PrincipalComponentAnalysisMotionSegmenter::PrincipalComponentAnalysisMotionSegmenter(MotionPtr motion, SegmentationType segType)
    : motion(motion)
    , segType(segType)
{
    // NOTE(review): raw new; ownership presumably passes to whoever embeds getWidget() (Qt parent) - confirm.
    this->widget = new PrincipalComponentAnalysisMotionSegmenterWidget(this->motion, this->segType);
}

/**
 * Segments the motion using the configuration currently held by the widget.
 *
 * @return the detected segments (see segment(configuration)).
 */
std::vector<SegmentationPtr> PrincipalComponentAnalysisMotionSegmenter::segment() {
    auto currentConfiguration = widget->getConfiguration();
    return segment(currentConfiguration);
}

/**
 * Builds the per-frame feature matrix for the current segmentation type and runs the
 * PCA-based segmentation on it. There is one matrix row per timestep.
 *
 * - MMM:        joint angles of the kinematic sensor, restricted to the joints named in
 *               configuration->segmentNames (names the sensor does not record are ignored).
 * - MMM_MARKER: global (x, y, z) positions of the named robot-model sensors, after posing the
 *               model with each frame's root pose and joint angles (unknown names are ignored).
 * - C3D_MARKER: (x, y, z) positions of the named labeled motion-capture markers.
 *
 * @param configuration segment names, reduced dimensionality, threshold and keyframe type; may be null
 * @return the detected segments. The result is empty if the configuration, the motion, a required
 *         sensor or the timesteps are missing, or for an unhandled segmentation type.
 */
std::vector<SegmentationPtr> PrincipalComponentAnalysisMotionSegmenter::segment(PrincipalComponentAnalysisMotionSegmenterConfigurationPtr configuration) {
    if (!configuration || !motion) return std::vector<SegmentationPtr>();

    switch (segType) {
    case SegmentationType::MMM:
        {
            KinematicSensorPtr sensor = motion->getSensorByType<KinematicSensor>(KinematicSensor::TYPE);
            if (!sensor) return std::vector<SegmentationPtr>();
            const std::vector<float> timesteps = sensor->getTimesteps();
            if (timesteps.empty()) return std::vector<SegmentationPtr>();

            // Positions (in the joint angle vector) of the joints selected by the configuration.
            const std::vector<std::string> jointNames = sensor->getJointNames();
            std::vector<int> indexes;
            for (std::size_t i = 0; i < jointNames.size(); i++) {
                if (configuration->segmentNames.find(jointNames[i]) != configuration->segmentNames.end()) {
                    indexes.push_back(static_cast<int>(i));
                }
            }

            Eigen::MatrixXf matrix(timesteps.size(), indexes.size());
            for (std::size_t i = 0; i < timesteps.size(); i++) {
                KinematicSensorMeasurementPtr measurement = sensor->getDerivedMeasurement(timesteps[i]);
                // Fetch the angle vector once per frame rather than once per selected joint.
                const Eigen::VectorXf jointAngles = measurement->getJointAngles();
                for (std::size_t j = 0; j < indexes.size(); j++) {
                    matrix(i, j) = jointAngles[indexes[j]];
                }
            }

            return principalComponentAnalysis(matrix, timesteps, configuration->reducedDimensions, configuration->alpha, configuration->keyframeType);
        }
    case SegmentationType::MMM_MARKER:
        {
            KinematicSensorPtr kSensor = motion->getSensorByType<KinematicSensor>(KinematicSensor::TYPE);
            ModelPoseSensorPtr mpSensor = motion->getSensorByType<ModelPoseSensor>(ModelPoseSensor::TYPE);
            if (!kSensor || !mpSensor) return std::vector<SegmentationPtr>();
            const std::vector<float> timesteps = mpSensor->getTimesteps();
            if (timesteps.empty()) return std::vector<SegmentationPtr>();
            const std::vector<std::string> jointNames = kSensor->getJointNames();

            VirtualRobot::RobotPtr robot = MMM::SimoxTools::buildModel(motion->getModel());
            if (!robot) return std::vector<SegmentationPtr>();
            MMM::SimoxTools::updateInertialMatricesFromModels(robot);

            // Resolve the marker sensors once instead of once per frame. Names the model does not
            // provide are skipped (previously a null sensor was dereferenced), mirroring how
            // unknown joint names are ignored in the MMM case.
            std::vector<VirtualRobot::SensorPtr> markerSensors;
            for (const std::string& segmentName : configuration->segmentNames) {
                VirtualRobot::SensorPtr markerSensor = robot->getSensor(segmentName);
                if (markerSensor) markerSensors.push_back(markerSensor);
            }

            Eigen::MatrixXf matrix(timesteps.size(), markerSensors.size() * 3);
            // Same keys every frame; only the values are overwritten.
            std::map<std::string, float> jointValues;
            for (std::size_t i = 0; i < timesteps.size(); i++) {
                const float timestep = timesteps[i];
                robot->setGlobalPose(mpSensor->getDerivedMeasurement(timestep)->getRootPose());

                const Eigen::VectorXf jointAngles = kSensor->getDerivedMeasurement(timestep)->getJointAngles();
                for (std::size_t j = 0; j < jointNames.size(); j++) {
                    jointValues[jointNames[j]] = jointAngles(j);
                }
                robot->setJointValues(jointValues);

                for (std::size_t j = 0; j < markerSensors.size(); j++) {
                    const Eigen::Vector3f position = markerSensors[j]->getGlobalPose().col(3).head<3>();
                    for (int k = 0; k < 3; k++) matrix(i, 3 * j + k) = position(k);
                }
            }

            return principalComponentAnalysis(matrix, timesteps, configuration->reducedDimensions, configuration->alpha, configuration->keyframeType);
        }
    case SegmentationType::C3D_MARKER:
        {
            MoCapMarkerSensorPtr sensor = motion->getSensorByType<MoCapMarkerSensor>(MoCapMarkerSensor::TYPE);
            if (!sensor) return std::vector<SegmentationPtr>();
            const std::vector<float> timesteps = sensor->getTimesteps();
            if (timesteps.empty()) return std::vector<SegmentationPtr>();

            Eigen::MatrixXf matrix(timesteps.size(), configuration->segmentNames.size() * 3);
            for (std::size_t i = 0; i < timesteps.size(); i++) {
                MoCapMarkerSensorMeasurementPtr measurement = sensor->getDerivedMeasurement(timesteps[i]);
                const auto& labeledMarker = measurement->getLabeledMarker();
                std::size_t j = 0;
                for (const std::string& segmentName : configuration->segmentNames) {
                    // operator[] would default-construct an Eigen::Vector3f, which Eigen leaves
                    // uninitialized, injecting garbage for markers missing in this frame.
                    // Use a deterministic zero position instead.
                    // NOTE(review): a marker missing only in some frames still shows up as a jump.
                    Eigen::Vector3f markerPos = Eigen::Vector3f::Zero();
                    const auto marker = labeledMarker.find(segmentName);
                    if (marker != labeledMarker.end()) markerPos = marker->second;
                    for (int k = 0; k < 3; k++) matrix(i, 3 * j + k) = markerPos(k);
                    j++;
                }
            }

            return principalComponentAnalysis(matrix, timesteps, configuration->reducedDimensions, configuration->alpha, configuration->keyframeType);
        }
    default:
        break;
    }
    return std::vector<SegmentationPtr>();
}

/**
 * Segments a motion by PCA reconstruction error.
 *
 * The frames (rows of `matrix`) are projected onto the `reducedDimensions` principal
 * components with the largest variance and then reconstructed. Every maximal run of
 * consecutive frames whose Euclidean reconstruction error exceeds `alpha` becomes one segment,
 * spanning from the first to the last frame of the run.
 *
 * @param matrix            one row per frame, one column per feature
 * @param timesteps         timestamp of each row of `matrix`
 * @param reducedDimensions number of principal components kept; clamped to [0, matrix.cols()]
 * @param alpha             reconstruction-error threshold
 * @param keyframeType      keyframe type assigned to every produced segment
 * @return the detected segments. The result is empty when there are fewer than two frames or no
 *         features.
 */
std::vector<SegmentationPtr> PrincipalComponentAnalysisMotionSegmenter::principalComponentAnalysis(Eigen::MatrixXf matrix, const std::vector<float> &timesteps, int reducedDimensions, float alpha, KeyframeType keyframeType) {
    std::vector<SegmentationPtr> segmentation;

    // Fewer than two frames have no variance (and would divide by zero below);
    // without features there is nothing to project.
    if (matrix.rows() < 2 || matrix.cols() == 0 || timesteps.empty()) return segmentation;

    // rightCols() requires 0 <= n <= cols().
    Eigen::MatrixXf::Index keptDimensions = reducedDimensions;
    if (keptDimensions < 0) keptDimensions = 0;
    if (keptDimensions > matrix.cols()) keptDimensions = matrix.cols();

    const Eigen::RowVectorXf mean = matrix.colwise().mean();
    const Eigen::MatrixXf centered = matrix.rowwise() - mean;
    const Eigen::MatrixXf covarianceMatrix = (centered.adjoint() * centered) / static_cast<float>(matrix.rows() - 1);

    // SelfAdjointEigenSolver sorts eigenvalues in increasing order, so the components with the
    // largest variance are the rightmost eigenvectors.
    Eigen::SelfAdjointEigenSolver<Eigen::MatrixXf> eig(covarianceMatrix);
    const Eigen::MatrixXf pcaProjectionMatrix = eig.eigenvectors().rightCols(keptDimensions);

    // Project into the reduced space and back, then measure how far each frame is from its reconstruction.
    const Eigen::MatrixXf projected = centered * pcaProjectionMatrix * pcaProjectionMatrix.transpose();
    const Eigen::MatrixXf restoredDim = projected.rowwise() + mean;
    const Eigen::VectorXf reconstructionError = (restoredDim - matrix).rowwise().norm();

    // Guard against a timesteps vector that does not match the number of rows.
    std::size_t frameCount = timesteps.size();
    if (static_cast<std::size_t>(matrix.rows()) < frameCount) frameCount = static_cast<std::size_t>(matrix.rows());

    float segmentStart = 0.0f;
    bool inSegment = false;
    for (std::size_t i = 0; i < frameCount; i++) {
        // A NaN error compares false and therefore never opens a segment (as before).
        const bool exceeds = reconstructionError(i) > alpha;
        if (exceeds && !inSegment) {
            segmentStart = timesteps[i];
            inSegment = true;
        } else if (!exceeds && inSegment) {
            inSegment = false;
            // keyframeType was previously only applied to a trailing segment; apply it to all.
            segmentation.push_back(SegmentationPtr(new Segmentation(segmentStart, timesteps[i - 1], keyframeType)));
        }
    }
    if (inSegment) segmentation.push_back(SegmentationPtr(new Segmentation(segmentStart, timesteps[frameCount - 1], keyframeType)));

    return segmentation;
}

/**
 * Switches the feature set used for segmentation.
 * Both the segmenter and its configuration widget are updated.
 *
 * @param segType the new segmentation type
 */
void PrincipalComponentAnalysisMotionSegmenter::setSegmentationType(SegmentationType segType) {
    SegmentationType& ownType = this->segType;
    ownType = segType;
    widget->setSegmentationType(ownType);
}

/**
 * Replaces the motion to be segmented.
 * Both the segmenter and its configuration widget are updated.
 *
 * @param motion the new motion
 */
void PrincipalComponentAnalysisMotionSegmenter::setMotion(MMM::MotionPtr motion) {
    MMM::MotionPtr& ownMotion = this->motion;
    ownMotion = motion;
    widget->setMotion(ownMotion);
}

/**
 * @return the configuration widget created in the constructor (non-null after construction).
 */
QWidget* PrincipalComponentAnalysisMotionSegmenter::getWidget() {
    QWidget* configurationWidget = this->widget;
    return configurationWidget;
}

/**
 * @return the segmenter's name, i.e. the value of NAME.
 */
std::string PrincipalComponentAnalysisMotionSegmenter::getName() {
    const std::string name = NAME;
    return name;
}
